{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zkufh760uvF3"
   },
   "source": [
    "![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/nlu/blob/master/examples/colab/Training/binary_text_classification/NLU_training_sentiment_classifier_demo.ipynb)\n",
    "\n",
    "# Training a Sentiment Analysis Classifier with NLU\n",
    "With the [ClassifierDL model](https://nlp.johnsnowlabs.com/docs/en/annotators#classifierdl-multi-class-text-classification) from Spark NLP you can achieve state-of-the-art results on multi-class text classification problems, including binary sentiment analysis as shown here.\n",
    "\n",
    "This notebook showcases the following features:\n",
    "\n",
    "- How to train the deep learning classifier\n",
    "- How to store a pipeline to disk\n",
    "- How to load the pipeline from disk (Enables NLU offline mode)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dur2drhW5Rvi"
   },
   "source": [
    "# 1. Install NLU via the johnsnowlabs package"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hFGnBCHavltY"
   },
   "outputs": [],
   "source": [
    "!pip install -q johnsnowlabs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f4KkTfnR5Ugg"
   },
   "source": [
    "# 2. Download Stock Market Sentiment dataset\n",
    "https://www.kaggle.com/yash612/stockmarket-sentiment-dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OrVb5ZMvvrQD"
   },
   "outputs": [],
   "source": [
    "! wget https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/stock_data/stock_data.csv\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "uDGIOASY_fRj",
    "outputId": "bdd53b41-7b1e-47e8-9783-b4d8e3315d84"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning::Spark Session already created, some configs may not take.\n",
      "sentimentdl_glove_imdb download started this may take some time.\n",
      "Approximate size to download 8.7 MB\n",
      "[OK!]\n",
      "glove_100d download started this may take some time.\n",
      "Approximate size to download 145.3 MB\n",
      "[OK!]\n"
     ]
    }
   ],
   "source": [
    "from johnsnowlabs import nlp\n",
    "sentiment = nlp.load('sentiment')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 150
    },
    "id": "U0ENiuMc_kyb",
    "outputId": "5d5d2dc8-7481-468c-ff8c-0c45fdc0b0c7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sentence_detector_dl download started this may take some time.\n",
      "Approximate size to download 354.6 KB\n",
      "[OK!]\n",
      "Warning::Spark Session already created, some configs may not take.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-152d5537-93a5-4c1d-8e35-4f63cb2f164f\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>sentence_embedding_converter</th>\n",
       "      <th>sentiment</th>\n",
       "      <th>sentiment_confidence</th>\n",
       "      <th>word_embedding_glove</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>I'm very very not at all happy</td>\n",
       "      <td>[-0.2865465581417084, 0.25398728251457214, 0.2...</td>\n",
       "      <td>pos</td>\n",
       "      <td>0.999995</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-152d5537-93a5-4c1d-8e35-4f63cb2f164f')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-152d5537-93a5-4c1d-8e35-4f63cb2f164f button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-152d5537-93a5-4c1d-8e35-4f63cb2f164f');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                         sentence  \\\n",
       "0  I'm very very not at all happy   \n",
       "\n",
       "                        sentence_embedding_converter sentiment  \\\n",
       "0  [-0.2865465581417084, 0.25398728251457214, 0.2...       pos   \n",
       "\n",
       "  sentiment_confidence                               word_embedding_glove  \n",
       "0             0.999995  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentiment.predict(\"I'm very very not at all happy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "id": "y4xSRWIhwT28",
    "outputId": "c486e2d0-708b-4519-b43c-7427900236cf"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-6174d74d-0972-4526-b0af-127c2df2529d\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Kickers on my watchlist XIDE TIT SOQ PNK CPW B...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>user: AAP MOVIE. 55% return for the FEA/GEED i...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>user I'd be afraid to short AMZN - they are lo...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>MNTA Over 12.00</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OI  Over 21.37</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5786</th>\n",
       "      <td>Industry body CII said #discoms are likely to ...</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5787</th>\n",
       "      <td>#Gold prices slip below Rs 46,000 as #investor...</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5788</th>\n",
       "      <td>Workers at Bajaj Auto have agreed to a 10% wag...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5789</th>\n",
       "      <td>#Sharemarket LIVE: Sensex off day’s high, up 6...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5790</th>\n",
       "      <td>#Sensex, #Nifty climb off day's highs, still u...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5791 rows × 2 columns</p>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-6174d74d-0972-4526-b0af-127c2df2529d')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-6174d74d-0972-4526-b0af-127c2df2529d button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-6174d74d-0972-4526-b0af-127c2df2529d');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "\n",
       "<div id=\"df-6ead1b5b-5d89-40df-b401-0b518dc2f02f\">\n",
       "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-6ead1b5b-5d89-40df-b401-0b518dc2f02f')\"\n",
       "            title=\"Suggest charts.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "  </button>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "      --bg-color: #E8F0FE;\n",
       "      --fill-color: #1967D2;\n",
       "      --hover-bg-color: #E2EBFA;\n",
       "      --hover-fill-color: #174EA6;\n",
       "      --disabled-fill-color: #AAA;\n",
       "      --disabled-bg-color: #DDD;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "      --bg-color: #3B4455;\n",
       "      --fill-color: #D2E3FC;\n",
       "      --hover-bg-color: #434B5C;\n",
       "      --hover-fill-color: #FFFFFF;\n",
       "      --disabled-bg-color: #3B4455;\n",
       "      --disabled-fill-color: #666;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart {\n",
       "    background-color: var(--bg-color);\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: var(--fill-color);\n",
       "    height: 32px;\n",
       "    padding: 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: var(--hover-bg-color);\n",
       "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: var(--button-hover-fill-color);\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart-complete:disabled,\n",
       "  .colab-df-quickchart-complete:disabled:hover {\n",
       "    background-color: var(--disabled-bg-color);\n",
       "    fill: var(--disabled-fill-color);\n",
       "    box-shadow: none;\n",
       "  }\n",
       "\n",
       "  .colab-df-spinner {\n",
       "    border: 2px solid var(--fill-color);\n",
       "    border-color: transparent;\n",
       "    border-bottom-color: var(--fill-color);\n",
       "    animation:\n",
       "      spin 1s steps(1) infinite;\n",
       "  }\n",
       "\n",
       "  @keyframes spin {\n",
       "    0% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "      border-left-color: var(--fill-color);\n",
       "    }\n",
       "    20% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    30% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    40% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    60% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    80% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "    90% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "  }\n",
       "</style>\n",
       "\n",
       "  <script>\n",
       "    async function quickchart(key) {\n",
       "      const quickchartButtonEl =\n",
       "        document.querySelector('#' + key + ' button');\n",
       "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
       "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
       "      try {\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      } catch (error) {\n",
       "        console.error('Error during call to suggestCharts:', error);\n",
       "      }\n",
       "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
       "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
       "    }\n",
       "    (() => {\n",
       "      let quickchartButtonEl =\n",
       "        document.querySelector('#df-6ead1b5b-5d89-40df-b401-0b518dc2f02f button');\n",
       "      quickchartButtonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "    })();\n",
       "  </script>\n",
       "</div>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                                   text         y\n",
       "0     Kickers on my watchlist XIDE TIT SOQ PNK CPW B...  positive\n",
       "1     user: AAP MOVIE. 55% return for the FEA/GEED i...  positive\n",
       "2     user I'd be afraid to short AMZN - they are lo...  positive\n",
       "3                                     MNTA Over 12.00    positive\n",
       "4                                      OI  Over 21.37    positive\n",
       "...                                                 ...       ...\n",
       "5786  Industry body CII said #discoms are likely to ...  negative\n",
       "5787  #Gold prices slip below Rs 46,000 as #investor...  negative\n",
       "5788  Workers at Bajaj Auto have agreed to a 10% wag...  positive\n",
       "5789  #Sharemarket LIVE: Sensex off day’s high, up 6...  positive\n",
       "5790  #Sensex, #Nifty climb off day's highs, still u...  positive\n",
       "\n",
       "[5791 rows x 2 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "train_path = '/content/stock_data.csv'\n",
    "\n",
    "train_df = pd.read_csv(train_path)\n",
    "# The text to classify must be in a column named 'text'.\n",
    "# The label column must be named 'y' and be of type str.\n",
    "train_df.columns=['text','y']\n",
    "train_df.y = train_df.y.astype(str)\n",
    "# Map the numeric labels to names; replace '-1' first so the plain '1' replacement cannot touch it.\n",
    "train_df.y = train_df.y.str.replace('-1','negative')\n",
    "train_df.y = train_df.y.str.replace('1','positive')\n",
    "train_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0296Om2C5anY"
   },
   "source": [
    "# 3. Train Deep Learning Classifier using nlp.load('train.sentiment')\n",
    "\n",
    "Your dataset's label column should be named 'y', and the feature column with the text data should be named 'text'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 806
    },
    "id": "3ZIPkRkWftBG",
    "outputId": "91874f31-7e1a-4180-c354-5317930d5c4c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning::Spark Session already created, some configs may not take.\n",
      "Warning::Spark Session already created, some configs may not take.\n",
      "sent_small_bert_L2_128 download started this may take some time.\n",
      "Approximate size to download 16.1 MB\n",
      "[OK!]\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative       0.00      0.00      0.00      2106\n",
      "    positive       0.64      1.00      0.78      3685\n",
      "\n",
      "    accuracy                           0.64      5791\n",
      "   macro avg       0.32      0.50      0.39      5791\n",
      "weighted avg       0.40      0.64      0.49      5791\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-963f7248-6cae-476a-884c-a98ad642ef82\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>document</th>\n",
       "      <th>sentence_embedding_small_bert_L2_128</th>\n",
       "      <th>sentiment</th>\n",
       "      <th>sentiment_confidence</th>\n",
       "      <th>text</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Kickers on my watchlist XIDE TIT SOQ PNK CPW B...</td>\n",
       "      <td>[-0.9530864357948303, 0.2135828286409378, 0.10...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Kickers on my watchlist XIDE TIT SOQ PNK CPW B...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>user: AAP MOVIE. 55% return for the FEA/GEED i...</td>\n",
       "      <td>[-0.4725969433784485, 0.5354134440422058, -0.2...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>user: AAP MOVIE. 55% return for the FEA/GEED i...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>user I'd be afraid to short AMZN - they are lo...</td>\n",
       "      <td>[0.30400288105010986, 0.22862982749938965, -0....</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>user I'd be afraid to short AMZN - they are lo...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>MNTA Over 12.00</td>\n",
       "      <td>[-1.707902193069458, -0.48472753167152405, -0....</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>MNTA Over 12.00</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OI Over 21.37</td>\n",
       "      <td>[-2.3011534214019775, 0.2649511396884918, -0.4...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>OI  Over 21.37</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5786</th>\n",
       "      <td>Industry body CII said #discoms are likely to ...</td>\n",
       "      <td>[-0.21655204892158508, 0.6153537631034851, 0.0...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Industry body CII said #discoms are likely to ...</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5787</th>\n",
       "      <td>#Gold prices slip below Rs 46,000 as #investor...</td>\n",
       "      <td>[-0.19915254414081573, 0.2607441842556, 0.0032...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>#Gold prices slip below Rs 46,000 as #investor...</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5788</th>\n",
       "      <td>Workers at Bajaj Auto have agreed to a 10% wag...</td>\n",
       "      <td>[-0.4361518919467926, 0.9346759915351868, -0.3...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Workers at Bajaj Auto have agreed to a 10% wag...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5789</th>\n",
       "      <td>#Sharemarket LIVE: Sensex off day’s high, up 6...</td>\n",
       "      <td>[-0.6081278920173645, 0.2732301354408264, 0.25...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>#Sharemarket LIVE: Sensex off day’s high, up 6...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5790</th>\n",
       "      <td>#Sensex, #Nifty climb off day's highs, still u...</td>\n",
       "      <td>[-0.5274896621704102, 0.4326432943344116, 0.06...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>#Sensex, #Nifty climb off day's highs, still u...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5791 rows × 6 columns</p>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-963f7248-6cae-476a-884c-a98ad642ef82')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-963f7248-6cae-476a-884c-a98ad642ef82 button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-963f7248-6cae-476a-884c-a98ad642ef82');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "\n",
       "<div id=\"df-11a9cf4e-e7e4-402a-a56e-0db399a353bf\">\n",
       "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-11a9cf4e-e7e4-402a-a56e-0db399a353bf')\"\n",
       "            title=\"Suggest charts.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "  </button>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "      --bg-color: #E8F0FE;\n",
       "      --fill-color: #1967D2;\n",
       "      --hover-bg-color: #E2EBFA;\n",
       "      --hover-fill-color: #174EA6;\n",
       "      --disabled-fill-color: #AAA;\n",
       "      --disabled-bg-color: #DDD;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "      --bg-color: #3B4455;\n",
       "      --fill-color: #D2E3FC;\n",
       "      --hover-bg-color: #434B5C;\n",
       "      --hover-fill-color: #FFFFFF;\n",
       "      --disabled-bg-color: #3B4455;\n",
       "      --disabled-fill-color: #666;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart {\n",
       "    background-color: var(--bg-color);\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: var(--fill-color);\n",
       "    height: 32px;\n",
       "    padding: 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: var(--hover-bg-color);\n",
       "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: var(--button-hover-fill-color);\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart-complete:disabled,\n",
       "  .colab-df-quickchart-complete:disabled:hover {\n",
       "    background-color: var(--disabled-bg-color);\n",
       "    fill: var(--disabled-fill-color);\n",
       "    box-shadow: none;\n",
       "  }\n",
       "\n",
       "  .colab-df-spinner {\n",
       "    border: 2px solid var(--fill-color);\n",
       "    border-color: transparent;\n",
       "    border-bottom-color: var(--fill-color);\n",
       "    animation:\n",
       "      spin 1s steps(1) infinite;\n",
       "  }\n",
       "\n",
       "  @keyframes spin {\n",
       "    0% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "      border-left-color: var(--fill-color);\n",
       "    }\n",
       "    20% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    30% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    40% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    60% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    80% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "    90% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "  }\n",
       "</style>\n",
       "\n",
       "  <script>\n",
       "    async function quickchart(key) {\n",
       "      const quickchartButtonEl =\n",
       "        document.querySelector('#' + key + ' button');\n",
       "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
       "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
       "      try {\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      } catch (error) {\n",
       "        console.error('Error during call to suggestCharts:', error);\n",
       "      }\n",
       "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
       "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
       "    }\n",
       "    (() => {\n",
       "      let quickchartButtonEl =\n",
       "        document.querySelector('#df-11a9cf4e-e7e4-402a-a56e-0db399a353bf button');\n",
       "      quickchartButtonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "    })();\n",
       "  </script>\n",
       "</div>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                               document  \\\n",
       "0     Kickers on my watchlist XIDE TIT SOQ PNK CPW B...   \n",
       "1     user: AAP MOVIE. 55% return for the FEA/GEED i...   \n",
       "2     user I'd be afraid to short AMZN - they are lo...   \n",
       "3                                       MNTA Over 12.00   \n",
       "4                                         OI Over 21.37   \n",
       "...                                                 ...   \n",
       "5786  Industry body CII said #discoms are likely to ...   \n",
       "5787  #Gold prices slip below Rs 46,000 as #investor...   \n",
       "5788  Workers at Bajaj Auto have agreed to a 10% wag...   \n",
       "5789  #Sharemarket LIVE: Sensex off day’s high, up 6...   \n",
       "5790  #Sensex, #Nifty climb off day's highs, still u...   \n",
       "\n",
       "                   sentence_embedding_small_bert_L2_128 sentiment  \\\n",
       "0     [-0.9530864357948303, 0.2135828286409378, 0.10...  positive   \n",
       "1     [-0.4725969433784485, 0.5354134440422058, -0.2...  positive   \n",
       "2     [0.30400288105010986, 0.22862982749938965, -0....  positive   \n",
       "3     [-1.707902193069458, -0.48472753167152405, -0....  positive   \n",
       "4     [-2.3011534214019775, 0.2649511396884918, -0.4...  positive   \n",
       "...                                                 ...       ...   \n",
       "5786  [-0.21655204892158508, 0.6153537631034851, 0.0...  positive   \n",
       "5787  [-0.19915254414081573, 0.2607441842556, 0.0032...  positive   \n",
       "5788  [-0.4361518919467926, 0.9346759915351868, -0.3...  positive   \n",
       "5789  [-0.6081278920173645, 0.2732301354408264, 0.25...  positive   \n",
       "5790  [-0.5274896621704102, 0.4326432943344116, 0.06...  positive   \n",
       "\n",
       "     sentiment_confidence                                               text  \\\n",
       "0                     1.0  Kickers on my watchlist XIDE TIT SOQ PNK CPW B...   \n",
       "1                     1.0  user: AAP MOVIE. 55% return for the FEA/GEED i...   \n",
       "2                     1.0  user I'd be afraid to short AMZN - they are lo...   \n",
       "3                     1.0                                  MNTA Over 12.00     \n",
       "4                     1.0                                   OI  Over 21.37     \n",
       "...                   ...                                                ...   \n",
       "5786                  1.0  Industry body CII said #discoms are likely to ...   \n",
       "5787                  1.0  #Gold prices slip below Rs 46,000 as #investor...   \n",
       "5788                  1.0  Workers at Bajaj Auto have agreed to a 10% wag...   \n",
       "5789                  1.0  #Sharemarket LIVE: Sensex off day’s high, up 6...   \n",
       "5790                  1.0  #Sensex, #Nifty climb off day's highs, still u...   \n",
       "\n",
       "             y  \n",
       "0     positive  \n",
       "1     positive  \n",
       "2     positive  \n",
       "3     positive  \n",
       "4     positive  \n",
       "...        ...  \n",
       "5786  negative  \n",
       "5787  negative  \n",
       "5788  positive  \n",
       "5789  positive  \n",
       "5790  positive  \n",
       "\n",
       "[5791 rows x 6 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "# Load a trainable pipeline by specifying the train. prefix and fit it on a dataset with label and text columns\n",
    "# By default, Universal Sentence Encoder (USE) sentence embeddings are used to embed the text\n",
    "trainable_pipe = nlp.load('train.sentiment')\n",
    "fitted_pipe = trainable_pipe.fit(train_df)\n",
    "\n",
    "# Predict with the fitted pipeline on the dataset\n",
    "preds = fitted_pipe.predict(train_df,output_level='document')\n",
    "# The sentence detector that is part of the pipe generates some NaNs, so drop them first\n",
    "preds.dropna(inplace=True)\n",
    "print(classification_report(preds['y'], preds['sentiment']))\n",
    "\n",
    "preds"
   ]
  },
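  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check after training is to compare the classification report against a majority-class baseline, i.e. a model that always predicts the most frequent label. The sketch below is plain Python and not part of NLU; the helper name `majority_baseline` is hypothetical, and the class counts (2106 negative, 3685 positive) are the support values from this notebook's classification report.\n",
    "\n",
    "```python\n",
    "# Class counts taken from the support column of this notebook's classification report\n",
    "support = {'negative': 2106, 'positive': 3685}\n",
    "\n",
    "\n",
    "def majority_baseline(support):\n",
    "    # Hypothetical helper: metrics obtained by always predicting the most frequent class\n",
    "    majority = max(support, key=support.get)\n",
    "    total = sum(support.values())\n",
    "    # Every row is predicted as the majority class, so precision equals its share of the data\n",
    "    precision = support[majority] / total\n",
    "    # All true majority rows are recovered\n",
    "    recall = 1.0\n",
    "    f1 = 2 * precision * recall / (precision + recall)\n",
    "    return {\n",
    "        'label': majority,\n",
    "        'accuracy': precision,\n",
    "        'precision': precision,\n",
    "        'recall': recall,\n",
    "        'f1': f1,\n",
    "    }\n",
    "\n",
    "\n",
    "baseline = majority_baseline(support)\n",
    "for name, value in baseline.items():\n",
    "    print(name, round(value, 2) if isinstance(value, float) else value)\n",
    "```\n",
    "\n",
    "If a fitted pipe scores no better than this baseline (accuracy 0.64 and positive-class F1 0.78 for these counts), it is probably predicting a single label for every row, which suggests the training parameters need tuning."
   ]
  },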
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lVyOE2wV0fw_"
   },
   "source": [
    "# Test the fitted pipe on a new example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 150
    },
    "id": "qdCUg2MR0PD2",
    "outputId": "d68ec313-a55c-4d59-c162-4b985b238ebc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sentence_detector_dl download started this may take some time.\n",
      "Approximate size to download 354.6 KB\n",
      "[OK!]\n",
      "Warning::Spark Session already created, some configs may not take.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-c2f7162c-9424-4049-9307-997e8fb54d10\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>sentence_embedding_small_bert_L2_128</th>\n",
       "      <th>sentiment</th>\n",
       "      <th>sentiment_confidence</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Bitcoin is going to the moon!</td>\n",
       "      <td>[-1.0531491041183472, -0.2827455699443817, -0....</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                        sentence  \\\n",
       "0  Bitcoin is going to the moon!   \n",
       "\n",
       "                sentence_embedding_small_bert_L2_128 sentiment  \\\n",
       "0  [-1.0531491041183472, -0.2827455699443817, -0....  positive   \n",
       "\n",
       "  sentiment_confidence  \n",
       "0                  1.0  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fitted_pipe.predict(\"Bitcoin is going to the moon!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xflpwrVjjBVD"
   },
   "source": [
    "## Configure pipe training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "UtsAUGTmOTms",
    "outputId": "8b700f48-678d-4e18-fab7-b03a3eba15b0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following parameters are configurable for this NLU pipeline (You can copy paste the examples) :\n",
      ">>> component_list['bert_sentence_embeddings@sent_small_bert_L2_128'] has settable params:\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setBatchSize(8)              | Info: Size of every batch | Currently set to : 8\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setEngine('tensorflow')      | Info: Deep Learning engine used for this model | Currently set to : tensorflow\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setIsLong(False)             | Info: Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int. | Currently set to : False\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setMaxSentenceLength(128)    | Info: Max sentence length to process | Currently set to : 128\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setDimension(128)            | Info: Number of embedding dimensions | Currently set to : 128\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setCaseSensitive(False)      | Info: whether to ignore case in tokens for embeddings matching | Currently set to : False\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setStorageRef('sent_small_bert_L2_128')  | Info: unique reference name for identification | Currently set to : sent_small_bert_L2_128\n",
      ">>> component_list['document_assembler'] has settable params:\n",
      "component_list['document_assembler'].setCleanupMode('shrink')                                  | Info: possible values: disabled, inplace, inplace_full, shrink, shrink_full, each, each_full, delete_full | Currently set to : shrink\n",
      ">>> component_list['sentiment_dl@sent_small_bert_L2_128'] has settable params:\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setEngine('tensorflow')                  | Info: Deep Learning engine used for this model | Currently set to : tensorflow\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setThreshold(0.6)                        | Info: The minimum threshold for the final result otheriwse it will be neutral | Currently set to : 0.6\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setThresholdLabel('neutral')             | Info: In case the score is less than threshold, what should be the label. Default is neutral. | Currently set to : neutral\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setStorageRef('sent_small_bert_L2_128')  | Info: unique reference name for identification | Currently set to : sent_small_bert_L2_128\n"
     ]
    }
   ],
   "source": [
    "trainable_pipe.print_info()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2GJdDNV9jEIe"
   },
   "source": [
    "## Retrain with new parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 806
    },
    "id": "mptfvHx-MMMX",
    "outputId": "57b240b3-3c7d-4fd1-a1e7-3d0f6212a02a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning::Spark Session already created, some configs may not take.\n",
      "Warning::Spark Session already created, some configs may not take.\n",
      "sent_small_bert_L2_128 download started this may take some time.\n",
      "Approximate size to download 16.1 MB\n",
      "[OK!]\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative       0.00      0.00      0.00      2106\n",
      "    positive       0.64      1.00      0.78      3685\n",
      "\n",
      "    accuracy                           0.64      5791\n",
      "   macro avg       0.32      0.50      0.39      5791\n",
      "weighted avg       0.40      0.64      0.49      5791\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-902baaea-fe0b-49c6-9b45-f3ee5f131f96\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>document</th>\n",
       "      <th>sentence_embedding_small_bert_L2_128</th>\n",
       "      <th>sentiment</th>\n",
       "      <th>sentiment_confidence</th>\n",
       "      <th>text</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Kickers on my watchlist XIDE TIT SOQ PNK CPW B...</td>\n",
       "      <td>[-0.9530864357948303, 0.2135828286409378, 0.10...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Kickers on my watchlist XIDE TIT SOQ PNK CPW B...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>user: AAP MOVIE. 55% return for the FEA/GEED i...</td>\n",
       "      <td>[-0.4725969433784485, 0.5354134440422058, -0.2...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>user: AAP MOVIE. 55% return for the FEA/GEED i...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>user I'd be afraid to short AMZN - they are lo...</td>\n",
       "      <td>[0.30400288105010986, 0.22862982749938965, -0....</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>user I'd be afraid to short AMZN - they are lo...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>MNTA Over 12.00</td>\n",
       "      <td>[-1.707902193069458, -0.48472753167152405, -0....</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>MNTA Over 12.00</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OI Over 21.37</td>\n",
       "      <td>[-2.3011534214019775, 0.2649511396884918, -0.4...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>OI  Over 21.37</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5786</th>\n",
       "      <td>Industry body CII said #discoms are likely to ...</td>\n",
       "      <td>[-0.21655204892158508, 0.6153537631034851, 0.0...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Industry body CII said #discoms are likely to ...</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5787</th>\n",
       "      <td>#Gold prices slip below Rs 46,000 as #investor...</td>\n",
       "      <td>[-0.19915254414081573, 0.2607441842556, 0.0032...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>#Gold prices slip below Rs 46,000 as #investor...</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5788</th>\n",
       "      <td>Workers at Bajaj Auto have agreed to a 10% wag...</td>\n",
       "      <td>[-0.4361518919467926, 0.9346759915351868, -0.3...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Workers at Bajaj Auto have agreed to a 10% wag...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5789</th>\n",
       "      <td>#Sharemarket LIVE: Sensex off day’s high, up 6...</td>\n",
       "      <td>[-0.6081278920173645, 0.2732301354408264, 0.25...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>#Sharemarket LIVE: Sensex off day’s high, up 6...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5790</th>\n",
       "      <td>#Sensex, #Nifty climb off day's highs, still u...</td>\n",
       "      <td>[-0.5274896621704102, 0.4326432943344116, 0.06...</td>\n",
       "      <td>positive</td>\n",
       "      <td>1.0</td>\n",
       "      <td>#Sensex, #Nifty climb off day's highs, still u...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5791 rows × 6 columns</p>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-902baaea-fe0b-49c6-9b45-f3ee5f131f96')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-902baaea-fe0b-49c6-9b45-f3ee5f131f96 button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-902baaea-fe0b-49c6-9b45-f3ee5f131f96');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "\n",
       "<div id=\"df-5773d7e9-8d96-4cd5-a86e-a64909015e99\">\n",
       "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-5773d7e9-8d96-4cd5-a86e-a64909015e99')\"\n",
       "            title=\"Suggest charts.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "  </button>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "      --bg-color: #E8F0FE;\n",
       "      --fill-color: #1967D2;\n",
       "      --hover-bg-color: #E2EBFA;\n",
       "      --hover-fill-color: #174EA6;\n",
       "      --disabled-fill-color: #AAA;\n",
       "      --disabled-bg-color: #DDD;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "      --bg-color: #3B4455;\n",
       "      --fill-color: #D2E3FC;\n",
       "      --hover-bg-color: #434B5C;\n",
       "      --hover-fill-color: #FFFFFF;\n",
       "      --disabled-bg-color: #3B4455;\n",
       "      --disabled-fill-color: #666;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart {\n",
       "    background-color: var(--bg-color);\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: var(--fill-color);\n",
       "    height: 32px;\n",
       "    padding: 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: var(--hover-bg-color);\n",
       "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: var(--button-hover-fill-color);\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart-complete:disabled,\n",
       "  .colab-df-quickchart-complete:disabled:hover {\n",
       "    background-color: var(--disabled-bg-color);\n",
       "    fill: var(--disabled-fill-color);\n",
       "    box-shadow: none;\n",
       "  }\n",
       "\n",
       "  .colab-df-spinner {\n",
       "    border: 2px solid var(--fill-color);\n",
       "    border-color: transparent;\n",
       "    border-bottom-color: var(--fill-color);\n",
       "    animation:\n",
       "      spin 1s steps(1) infinite;\n",
       "  }\n",
       "\n",
       "  @keyframes spin {\n",
       "    0% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "      border-left-color: var(--fill-color);\n",
       "    }\n",
       "    20% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    30% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    40% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    60% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    80% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "    90% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "  }\n",
       "</style>\n",
       "\n",
       "  <script>\n",
       "    async function quickchart(key) {\n",
       "      const quickchartButtonEl =\n",
       "        document.querySelector('#' + key + ' button');\n",
       "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
       "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
       "      try {\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      } catch (error) {\n",
       "        console.error('Error during call to suggestCharts:', error);\n",
       "      }\n",
       "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
       "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
       "    }\n",
       "    (() => {\n",
       "      let quickchartButtonEl =\n",
       "        document.querySelector('#df-5773d7e9-8d96-4cd5-a86e-a64909015e99 button');\n",
       "      quickchartButtonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "    })();\n",
       "  </script>\n",
       "</div>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                               document  \\\n",
       "0     Kickers on my watchlist XIDE TIT SOQ PNK CPW B...   \n",
       "1     user: AAP MOVIE. 55% return for the FEA/GEED i...   \n",
       "2     user I'd be afraid to short AMZN - they are lo...   \n",
       "3                                       MNTA Over 12.00   \n",
       "4                                         OI Over 21.37   \n",
       "...                                                 ...   \n",
       "5786  Industry body CII said #discoms are likely to ...   \n",
       "5787  #Gold prices slip below Rs 46,000 as #investor...   \n",
       "5788  Workers at Bajaj Auto have agreed to a 10% wag...   \n",
       "5789  #Sharemarket LIVE: Sensex off day’s high, up 6...   \n",
       "5790  #Sensex, #Nifty climb off day's highs, still u...   \n",
       "\n",
       "                   sentence_embedding_small_bert_L2_128 sentiment  \\\n",
       "0     [-0.9530864357948303, 0.2135828286409378, 0.10...  positive   \n",
       "1     [-0.4725969433784485, 0.5354134440422058, -0.2...  positive   \n",
       "2     [0.30400288105010986, 0.22862982749938965, -0....  positive   \n",
       "3     [-1.707902193069458, -0.48472753167152405, -0....  positive   \n",
       "4     [-2.3011534214019775, 0.2649511396884918, -0.4...  positive   \n",
       "...                                                 ...       ...   \n",
       "5786  [-0.21655204892158508, 0.6153537631034851, 0.0...  positive   \n",
       "5787  [-0.19915254414081573, 0.2607441842556, 0.0032...  positive   \n",
       "5788  [-0.4361518919467926, 0.9346759915351868, -0.3...  positive   \n",
       "5789  [-0.6081278920173645, 0.2732301354408264, 0.25...  positive   \n",
       "5790  [-0.5274896621704102, 0.4326432943344116, 0.06...  positive   \n",
       "\n",
       "     sentiment_confidence                                               text  \\\n",
       "0                     1.0  Kickers on my watchlist XIDE TIT SOQ PNK CPW B...   \n",
       "1                     1.0  user: AAP MOVIE. 55% return for the FEA/GEED i...   \n",
       "2                     1.0  user I'd be afraid to short AMZN - they are lo...   \n",
       "3                     1.0                                  MNTA Over 12.00     \n",
       "4                     1.0                                   OI  Over 21.37     \n",
       "...                   ...                                                ...   \n",
       "5786                  1.0  Industry body CII said #discoms are likely to ...   \n",
       "5787                  1.0  #Gold prices slip below Rs 46,000 as #investor...   \n",
       "5788                  1.0  Workers at Bajaj Auto have agreed to a 10% wag...   \n",
       "5789                  1.0  #Sharemarket LIVE: Sensex off day’s high, up 6...   \n",
       "5790                  1.0  #Sensex, #Nifty climb off day's highs, still u...   \n",
       "\n",
       "             y  \n",
       "0     positive  \n",
       "1     positive  \n",
       "2     positive  \n",
       "3     positive  \n",
       "4     positive  \n",
       "...        ...  \n",
       "5786  negative  \n",
       "5787  negative  \n",
       "5788  positive  \n",
       "5789  positive  \n",
       "5790  positive  \n",
       "\n",
       "[5791 rows x 6 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train longer!\n",
    "trainable_pipe = nlp.load('train.sentiment')\n",
    "trainable_pipe['trainable_sentiment_dl'].setMaxEpochs(5)\n",
    "fitted_pipe = trainable_pipe.fit(train_df)\n",
    "# predict with the trainable pipeline on dataset and get predictions\n",
    "preds = fitted_pipe.predict(train_df,output_level='document')\n",
    "\n",
    "#sentence detector that is part of the pipe generates sone NaNs. lets drop them first\n",
    "preds.dropna(inplace=True)\n",
    "print(classification_report(preds['y'], preds['sentiment']))\n",
    "\n",
    "preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qFoT-s1MjTSS"
   },
   "source": [
    "# Try training with different Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nxWFzQOhjWC8",
    "outputId": "34e045c7-7402-448a-d52e-cb6ce06c9a00"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For language <am> NLU provides the following Models : \n",
      "nlu.load('am.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_amharic\n",
      "For language <de> NLU provides the following Models : \n",
      "nlu.load('de.embed_sentence.bert.base_cased') returns Spark NLP model_anno_obj sent_bert_base_cased\n",
      "For language <el> NLU provides the following Models : \n",
      "nlu.load('el.embed_sentence.bert.base_uncased') returns Spark NLP model_anno_obj sent_bert_base_uncased\n",
      "For language <en> NLU provides the following Models : \n",
      "nlu.load('en.embed_sentence') returns Spark NLP model_anno_obj tfhub_use\n",
      "nlu.load('en.embed_sentence.albert') returns Spark NLP model_anno_obj albert_base_uncased\n",
      "nlu.load('en.embed_sentence.bert') returns Spark NLP model_anno_obj sent_bert_base_uncased\n",
      "nlu.load('en.embed_sentence.bert.base_uncased_legal') returns Spark NLP model_anno_obj sent_bert_base_uncased_legal\n",
      "nlu.load('en.embed_sentence.bert.finetuned') returns Spark NLP model_anno_obj sbert_setfit_finetuned_financial_text_classification\n",
      "nlu.load('en.embed_sentence.bert.pubmed') returns Spark NLP model_anno_obj sent_bert_pubmed\n",
      "nlu.load('en.embed_sentence.bert.pubmed_squad2') returns Spark NLP model_anno_obj sent_bert_pubmed_squad2\n",
      "nlu.load('en.embed_sentence.bert.wiki_books') returns Spark NLP model_anno_obj sent_bert_wiki_books\n",
      "nlu.load('en.embed_sentence.bert.wiki_books_mnli') returns Spark NLP model_anno_obj sent_bert_wiki_books_mnli\n",
      "nlu.load('en.embed_sentence.bert.wiki_books_qnli') returns Spark NLP model_anno_obj sent_bert_wiki_books_qnli\n",
      "nlu.load('en.embed_sentence.bert.wiki_books_qqp') returns Spark NLP model_anno_obj sent_bert_wiki_books_qqp\n",
      "nlu.load('en.embed_sentence.bert.wiki_books_squad2') returns Spark NLP model_anno_obj sent_bert_wiki_books_squad2\n",
      "nlu.load('en.embed_sentence.bert.wiki_books_sst2') returns Spark NLP model_anno_obj sent_bert_wiki_books_sst2\n",
      "nlu.load('en.embed_sentence.bert_base_cased') returns Spark NLP model_anno_obj sent_bert_base_cased\n",
      "nlu.load('en.embed_sentence.bert_base_uncased') returns Spark NLP model_anno_obj sent_bert_base_uncased\n",
      "nlu.load('en.embed_sentence.bert_large_cased') returns Spark NLP model_anno_obj sent_bert_large_cased\n",
      "nlu.load('en.embed_sentence.bert_large_uncased') returns Spark NLP model_anno_obj sent_bert_large_uncased\n",
      "nlu.load('en.embed_sentence.bert_use_cmlm_en_base') returns Spark NLP model_anno_obj sent_bert_use_cmlm_en_base\n",
      "nlu.load('en.embed_sentence.bert_use_cmlm_en_large') returns Spark NLP model_anno_obj sent_bert_use_cmlm_en_large\n",
      "nlu.load('en.embed_sentence.biobert.clinical_base_cased') returns Spark NLP model_anno_obj sent_biobert_clinical_base_cased\n",
      "nlu.load('en.embed_sentence.biobert.discharge_base_cased') returns Spark NLP model_anno_obj sent_biobert_discharge_base_cased\n",
      "nlu.load('en.embed_sentence.biobert.pmc_base_cased') returns Spark NLP model_anno_obj sent_biobert_pmc_base_cased\n",
      "nlu.load('en.embed_sentence.biobert.pubmed_base_cased') returns Spark NLP model_anno_obj sent_biobert_pubmed_base_cased\n",
      "nlu.load('en.embed_sentence.biobert.pubmed_large_cased') returns Spark NLP model_anno_obj sent_biobert_pubmed_large_cased\n",
      "nlu.load('en.embed_sentence.biobert.pubmed_pmc_base_cased') returns Spark NLP model_anno_obj sent_biobert_pubmed_pmc_base_cased\n",
      "nlu.load('en.embed_sentence.covidbert.large_uncased') returns Spark NLP model_anno_obj sent_covidbert_large_uncased\n",
      "nlu.load('en.embed_sentence.distil_roberta.distilled_base') returns Spark NLP model_anno_obj sent_distilroberta_base\n",
      "nlu.load('en.embed_sentence.doc2vec') returns Spark NLP model_anno_obj doc2vec_gigaword_300\n",
      "nlu.load('en.embed_sentence.doc2vec.gigaword_300') returns Spark NLP model_anno_obj doc2vec_gigaword_300\n",
      "nlu.load('en.embed_sentence.doc2vec.gigaword_wiki_300') returns Spark NLP model_anno_obj doc2vec_gigaword_wiki_300\n",
      "nlu.load('en.embed_sentence.electra') returns Spark NLP model_anno_obj sent_electra_small_uncased\n",
      "nlu.load('en.embed_sentence.electra_base_uncased') returns Spark NLP model_anno_obj sent_electra_base_uncased\n",
      "nlu.load('en.embed_sentence.electra_large_uncased') returns Spark NLP model_anno_obj sent_electra_large_uncased\n",
      "nlu.load('en.embed_sentence.electra_small_uncased') returns Spark NLP model_anno_obj sent_electra_small_uncased\n",
      "nlu.load('en.embed_sentence.roberta.base') returns Spark NLP model_anno_obj sent_roberta_base\n",
      "nlu.load('en.embed_sentence.roberta.large') returns Spark NLP model_anno_obj sent_roberta_large\n",
      "nlu.load('en.embed_sentence.small_bert_L10_128') returns Spark NLP model_anno_obj sent_small_bert_L10_128\n",
      "nlu.load('en.embed_sentence.small_bert_L10_256') returns Spark NLP model_anno_obj sent_small_bert_L10_256\n",
      "nlu.load('en.embed_sentence.small_bert_L10_512') returns Spark NLP model_anno_obj sent_small_bert_L10_512\n",
      "nlu.load('en.embed_sentence.small_bert_L10_768') returns Spark NLP model_anno_obj sent_small_bert_L10_768\n",
      "nlu.load('en.embed_sentence.small_bert_L12_128') returns Spark NLP model_anno_obj sent_small_bert_L12_128\n",
      "nlu.load('en.embed_sentence.small_bert_L12_256') returns Spark NLP model_anno_obj sent_small_bert_L12_256\n",
      "nlu.load('en.embed_sentence.small_bert_L12_512') returns Spark NLP model_anno_obj sent_small_bert_L12_512\n",
      "nlu.load('en.embed_sentence.small_bert_L12_768') returns Spark NLP model_anno_obj sent_small_bert_L12_768\n",
      "nlu.load('en.embed_sentence.small_bert_L2_128') returns Spark NLP model_anno_obj sent_small_bert_L2_128\n",
      "nlu.load('en.embed_sentence.small_bert_L2_256') returns Spark NLP model_anno_obj sent_small_bert_L2_256\n",
      "nlu.load('en.embed_sentence.small_bert_L2_512') returns Spark NLP model_anno_obj sent_small_bert_L2_512\n",
      "nlu.load('en.embed_sentence.small_bert_L2_768') returns Spark NLP model_anno_obj sent_small_bert_L2_768\n",
      "nlu.load('en.embed_sentence.small_bert_L4_128') returns Spark NLP model_anno_obj sent_small_bert_L4_128\n",
      "nlu.load('en.embed_sentence.small_bert_L4_256') returns Spark NLP model_anno_obj sent_small_bert_L4_256\n",
      "nlu.load('en.embed_sentence.small_bert_L4_512') returns Spark NLP model_anno_obj sent_small_bert_L4_512\n",
      "nlu.load('en.embed_sentence.small_bert_L4_768') returns Spark NLP model_anno_obj sent_small_bert_L4_768\n",
      "nlu.load('en.embed_sentence.small_bert_L6_128') returns Spark NLP model_anno_obj sent_small_bert_L6_128\n",
      "nlu.load('en.embed_sentence.small_bert_L6_256') returns Spark NLP model_anno_obj sent_small_bert_L6_256\n",
      "nlu.load('en.embed_sentence.small_bert_L6_512') returns Spark NLP model_anno_obj sent_small_bert_L6_512\n",
      "nlu.load('en.embed_sentence.small_bert_L6_768') returns Spark NLP model_anno_obj sent_small_bert_L6_768\n",
      "nlu.load('en.embed_sentence.small_bert_L8_128') returns Spark NLP model_anno_obj sent_small_bert_L8_128\n",
      "nlu.load('en.embed_sentence.small_bert_L8_256') returns Spark NLP model_anno_obj sent_small_bert_L8_256\n",
      "nlu.load('en.embed_sentence.small_bert_L8_512') returns Spark NLP model_anno_obj sent_small_bert_L8_512\n",
      "nlu.load('en.embed_sentence.small_bert_L8_768') returns Spark NLP model_anno_obj sent_small_bert_L8_768\n",
      "nlu.load('en.embed_sentence.tfhub_use') returns Spark NLP model_anno_obj tfhub_use\n",
      "nlu.load('en.embed_sentence.tfhub_use.lg') returns Spark NLP model_anno_obj tfhub_use_lg\n",
      "nlu.load('en.embed_sentence.use') returns Spark NLP model_anno_obj tfhub_use\n",
      "nlu.load('en.embed_sentence.use.lg') returns Spark NLP model_anno_obj tfhub_use_lg\n",
      "For language <es> NLU provides the following Models : \n",
      "nlu.load('es.embed_sentence.bert.base_cased') returns Spark NLP model_anno_obj sent_bert_base_cased\n",
      "nlu.load('es.embed_sentence.bert.base_uncased') returns Spark NLP model_anno_obj sent_bert_base_uncased\n",
      "For language <fi> NLU provides the following Models : \n",
      "nlu.load('fi.embed_sentence.bert') returns Spark NLP model_anno_obj bert_base_finnish_uncased\n",
      "nlu.load('fi.embed_sentence.bert.cased') returns Spark NLP model_anno_obj bert_base_finnish_cased\n",
      "nlu.load('fi.embed_sentence.bert.uncased') returns Spark NLP model_anno_obj bert_base_finnish_uncased\n",
      "For language <ha> NLU provides the following Models : \n",
      "nlu.load('ha.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_hausa\n",
      "For language <ig> NLU provides the following Models : \n",
      "nlu.load('ig.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_igbo\n",
      "For language <lg> NLU provides the following Models : \n",
      "nlu.load('lg.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_luganda\n",
      "For language <nl> NLU provides the following Models : \n",
      "nlu.load('nl.embed_sentence.bert.base_cased') returns Spark NLP model_anno_obj sent_bert_base_cased\n",
      "For language <pcm> NLU provides the following Models : \n",
      "nlu.load('pcm.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_naija\n",
      "For language <pt> NLU provides the following Models : \n",
      "nlu.load('pt.embed_sentence.bert.base_legal') returns Spark NLP model_anno_obj sbert_legal_bertimbau_base_tsdae_sts\n",
      "nlu.load('pt.embed_sentence.bert.cased_large_legal') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.1\n",
      "nlu.load('pt.embed_sentence.bert.large_legal') returns Spark NLP model_anno_obj sbert_legal_bertimbau_large_gpl_sts\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v0.10.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.10\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v0.2.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.2\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v0.3.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.3\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v0.4.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.4\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v0.5.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.5\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v0.7.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.7\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v0.8.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.8\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v0.9.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v0.9\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_sts_v1.0.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_sts_v1.0\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_v0.11_gpl_nli_sts_v0.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_v0.11_gpl_nli_sts_v0\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_v0.11_gpl_nli_sts_v1.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_v0.11_gpl_nli_sts_v1\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_v0.11_nli_sts_v0.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_v0.11_nli_sts_v0\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_v0.11_nli_sts_v1.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_v0.11_nli_sts_v1\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_v0.11_sts_v0.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_v0.11_sts_v0\n",
      "nlu.load('pt.embed_sentence.bert.legal.cased_large_mlm_v0.11_sts_v1.by_stjiris') returns Spark NLP model_anno_obj sbert_bert_large_portuguese_cased_legal_mlm_v0.11_sts_v1\n",
      "nlu.load('pt.embed_sentence.bert.v2_base_legal') returns Spark NLP model_anno_obj sbert_legal_bertimbau_sts_base_ma_v2\n",
      "nlu.load('pt.embed_sentence.bert.v2_large_legal') returns Spark NLP model_anno_obj sbert_legal_bertimbau_large_tsdae_sts_v2\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.assin.base.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_sts_base_ma\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.assin2.base.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_sts_base\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.large_sts_by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_sts_large\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.large_sts_ma.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_sts_large_ma\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.large_sts_ma_v3.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_sts_large_ma_v3\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.large_tsdae_sts.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_large_tsdae_sts\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.large_tsdae_sts_v4.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_large_tsdae_sts_v4\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.large_tsdae_v4_gpl_sts.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_large_tsdae_v4_gpl_sts\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.v2_large_sts_v2.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_sts_large_v2\n",
      "nlu.load('pt.embed_sentence.bertimbau.legal.v2_large_v2_sts.by_rufimelo') returns Spark NLP model_anno_obj sbert_legal_bertimbau_large_v2_sts\n",
      "For language <rw> NLU provides the following Models : \n",
      "nlu.load('rw.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_kinyarwanda\n",
      "For language <sv> NLU provides the following Models : \n",
      "nlu.load('sv.embed_sentence.bert.base_cased') returns Spark NLP model_anno_obj sent_bert_base_cased\n",
      "For language <sw> NLU provides the following Models : \n",
      "nlu.load('sw.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_swahili\n",
      "For language <wo> NLU provides the following Models : \n",
      "nlu.load('wo.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_wolof\n",
      "For language <xx> NLU provides the following Models : \n",
      "nlu.load('xx.embed_sentence') returns Spark NLP model_anno_obj sent_bert_multi_cased\n",
      "nlu.load('xx.embed_sentence.bert') returns Spark NLP model_anno_obj sent_bert_multi_cased\n",
      "nlu.load('xx.embed_sentence.bert.cased') returns Spark NLP model_anno_obj sent_bert_multi_cased\n",
      "nlu.load('xx.embed_sentence.bert.muril') returns Spark NLP model_anno_obj sent_bert_muril\n",
      "nlu.load('xx.embed_sentence.bert_use_cmlm_multi_base') returns Spark NLP model_anno_obj sent_bert_use_cmlm_multi_base\n",
      "nlu.load('xx.embed_sentence.bert_use_cmlm_multi_base_br') returns Spark NLP model_anno_obj sent_bert_use_cmlm_multi_base_br\n",
      "nlu.load('xx.embed_sentence.labse') returns Spark NLP model_anno_obj labse\n",
      "nlu.load('xx.embed_sentence.xlm_roberta.base') returns Spark NLP model_anno_obj sent_xlm_roberta_base\n",
      "For language <yo> NLU provides the following Models : \n",
      "nlu.load('yo.embed_sentence.xlm_roberta') returns Spark NLP model_anno_obj sent_xlm_roberta_base_finetuned_yoruba\n",
      "For language <zh> NLU provides the following Models : \n",
      "nlu.load('zh.embed_sentence.bert') returns Spark NLP model_anno_obj sbert_chinese_qmc_finance_v1\n",
      "nlu.load('zh.embed_sentence.bert.distilled') returns Spark NLP model_anno_obj sbert_chinese_qmc_finance_v1_distill\n"
     ]
    }
   ],
   "source": [
    "# We can use nlu.print_components(action='embed_sentence') to see every possibler sentence embedding we could use. Lets use bert!\n",
    "nlp.nlu.print_components(action='embed_sentence')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 858
    },
    "id": "IKK_Ii_gjJfF",
    "outputId": "2650912d-7378-4149-9265-8d5e662ebeb8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning::Spark Session already created, some configs may not take.\n",
      "Warning::Spark Session already created, some configs may not take.\n",
      "sent_small_bert_L2_128 download started this may take some time.\n",
      "Approximate size to download 16.1 MB\n",
      "[OK!]\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative       0.69      0.24      0.36      2106\n",
      "     neutral       0.00      0.00      0.00         0\n",
      "    positive       0.72      0.85      0.78      3685\n",
      "\n",
      "    accuracy                           0.63      5791\n",
      "   macro avg       0.47      0.36      0.38      5791\n",
      "weighted avg       0.71      0.63      0.63      5791\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-c6b69be8-f6b6-4a05-8506-f53dcb65f689\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>document</th>\n",
       "      <th>sentence_embedding_bert</th>\n",
       "      <th>sentiment</th>\n",
       "      <th>sentiment_confidence</th>\n",
       "      <th>text</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Kickers on my watchlist XIDE TIT SOQ PNK CPW B...</td>\n",
       "      <td>[-0.9530864357948303, 0.2135828286409378, 0.10...</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.0</td>\n",
       "      <td>Kickers on my watchlist XIDE TIT SOQ PNK CPW B...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>user: AAP MOVIE. 55% return for the FEA/GEED i...</td>\n",
       "      <td>[-0.4725969433784485, 0.5354134440422058, -0.2...</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.0</td>\n",
       "      <td>user: AAP MOVIE. 55% return for the FEA/GEED i...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>user I'd be afraid to short AMZN - they are lo...</td>\n",
       "      <td>[0.30400288105010986, 0.22862982749938965, -0....</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.0</td>\n",
       "      <td>user I'd be afraid to short AMZN - they are lo...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>MNTA Over 12.00</td>\n",
       "      <td>[-1.707902193069458, -0.48472753167152405, -0....</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.0</td>\n",
       "      <td>MNTA Over 12.00</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OI Over 21.37</td>\n",
       "      <td>[-2.3011534214019775, 0.2649511396884918, -0.4...</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.0</td>\n",
       "      <td>OI  Over 21.37</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5786</th>\n",
       "      <td>Industry body CII said #discoms are likely to ...</td>\n",
       "      <td>[-0.21655204892158508, 0.6153537631034851, 0.0...</td>\n",
       "      <td>negative</td>\n",
       "      <td>0.0</td>\n",
       "      <td>Industry body CII said #discoms are likely to ...</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5787</th>\n",
       "      <td>#Gold prices slip below Rs 46,000 as #investor...</td>\n",
       "      <td>[-0.19915254414081573, 0.2607441842556, 0.0032...</td>\n",
       "      <td>negative</td>\n",
       "      <td>0.0</td>\n",
       "      <td>#Gold prices slip below Rs 46,000 as #investor...</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5788</th>\n",
       "      <td>Workers at Bajaj Auto have agreed to a 10% wag...</td>\n",
       "      <td>[-0.4361518919467926, 0.9346759915351868, -0.3...</td>\n",
       "      <td>negative</td>\n",
       "      <td>0.0</td>\n",
       "      <td>Workers at Bajaj Auto have agreed to a 10% wag...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5789</th>\n",
       "      <td>#Sharemarket LIVE: Sensex off day’s high, up 6...</td>\n",
       "      <td>[-0.6081278920173645, 0.2732301354408264, 0.25...</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.0</td>\n",
       "      <td>#Sharemarket LIVE: Sensex off day’s high, up 6...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5790</th>\n",
       "      <td>#Sensex, #Nifty climb off day's highs, still u...</td>\n",
       "      <td>[-0.5274896621704102, 0.4326432943344116, 0.06...</td>\n",
       "      <td>neutral</td>\n",
       "      <td>0.0</td>\n",
       "      <td>#Sensex, #Nifty climb off day's highs, still u...</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5791 rows × 6 columns</p>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-c6b69be8-f6b6-4a05-8506-f53dcb65f689')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-c6b69be8-f6b6-4a05-8506-f53dcb65f689 button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-c6b69be8-f6b6-4a05-8506-f53dcb65f689');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "\n",
       "<div id=\"df-96849220-b3fb-4dbd-84f7-92217950240d\">\n",
       "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-96849220-b3fb-4dbd-84f7-92217950240d')\"\n",
       "            title=\"Suggest charts.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "  </button>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "      --bg-color: #E8F0FE;\n",
       "      --fill-color: #1967D2;\n",
       "      --hover-bg-color: #E2EBFA;\n",
       "      --hover-fill-color: #174EA6;\n",
       "      --disabled-fill-color: #AAA;\n",
       "      --disabled-bg-color: #DDD;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "      --bg-color: #3B4455;\n",
       "      --fill-color: #D2E3FC;\n",
       "      --hover-bg-color: #434B5C;\n",
       "      --hover-fill-color: #FFFFFF;\n",
       "      --disabled-bg-color: #3B4455;\n",
       "      --disabled-fill-color: #666;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart {\n",
       "    background-color: var(--bg-color);\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: var(--fill-color);\n",
       "    height: 32px;\n",
       "    padding: 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: var(--hover-bg-color);\n",
       "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: var(--button-hover-fill-color);\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart-complete:disabled,\n",
       "  .colab-df-quickchart-complete:disabled:hover {\n",
       "    background-color: var(--disabled-bg-color);\n",
       "    fill: var(--disabled-fill-color);\n",
       "    box-shadow: none;\n",
       "  }\n",
       "\n",
       "  .colab-df-spinner {\n",
       "    border: 2px solid var(--fill-color);\n",
       "    border-color: transparent;\n",
       "    border-bottom-color: var(--fill-color);\n",
       "    animation:\n",
       "      spin 1s steps(1) infinite;\n",
       "  }\n",
       "\n",
       "  @keyframes spin {\n",
       "    0% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "      border-left-color: var(--fill-color);\n",
       "    }\n",
       "    20% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    30% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    40% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    60% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    80% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "    90% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "  }\n",
       "</style>\n",
       "\n",
       "  <script>\n",
       "    async function quickchart(key) {\n",
       "      const quickchartButtonEl =\n",
       "        document.querySelector('#' + key + ' button');\n",
       "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
       "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
       "      try {\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      } catch (error) {\n",
       "        console.error('Error during call to suggestCharts:', error);\n",
       "      }\n",
       "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
       "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
       "    }\n",
       "    (() => {\n",
       "      let quickchartButtonEl =\n",
       "        document.querySelector('#df-96849220-b3fb-4dbd-84f7-92217950240d button');\n",
       "      quickchartButtonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "    })();\n",
       "  </script>\n",
       "</div>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                               document  \\\n",
       "0     Kickers on my watchlist XIDE TIT SOQ PNK CPW B...   \n",
       "1     user: AAP MOVIE. 55% return for the FEA/GEED i...   \n",
       "2     user I'd be afraid to short AMZN - they are lo...   \n",
       "3                                       MNTA Over 12.00   \n",
       "4                                         OI Over 21.37   \n",
       "...                                                 ...   \n",
       "5786  Industry body CII said #discoms are likely to ...   \n",
       "5787  #Gold prices slip below Rs 46,000 as #investor...   \n",
       "5788  Workers at Bajaj Auto have agreed to a 10% wag...   \n",
       "5789  #Sharemarket LIVE: Sensex off day’s high, up 6...   \n",
       "5790  #Sensex, #Nifty climb off day's highs, still u...   \n",
       "\n",
       "                                sentence_embedding_bert sentiment  \\\n",
       "0     [-0.9530864357948303, 0.2135828286409378, 0.10...  positive   \n",
       "1     [-0.4725969433784485, 0.5354134440422058, -0.2...  positive   \n",
       "2     [0.30400288105010986, 0.22862982749938965, -0....  positive   \n",
       "3     [-1.707902193069458, -0.48472753167152405, -0....  positive   \n",
       "4     [-2.3011534214019775, 0.2649511396884918, -0.4...  positive   \n",
       "...                                                 ...       ...   \n",
       "5786  [-0.21655204892158508, 0.6153537631034851, 0.0...  negative   \n",
       "5787  [-0.19915254414081573, 0.2607441842556, 0.0032...  negative   \n",
       "5788  [-0.4361518919467926, 0.9346759915351868, -0.3...  negative   \n",
       "5789  [-0.6081278920173645, 0.2732301354408264, 0.25...   neutral   \n",
       "5790  [-0.5274896621704102, 0.4326432943344116, 0.06...   neutral   \n",
       "\n",
       "     sentiment_confidence                                               text  \\\n",
       "0                     0.0  Kickers on my watchlist XIDE TIT SOQ PNK CPW B...   \n",
       "1                     0.0  user: AAP MOVIE. 55% return for the FEA/GEED i...   \n",
       "2                     0.0  user I'd be afraid to short AMZN - they are lo...   \n",
       "3                     0.0                                  MNTA Over 12.00     \n",
       "4                     0.0                                   OI  Over 21.37     \n",
       "...                   ...                                                ...   \n",
       "5786                  0.0  Industry body CII said #discoms are likely to ...   \n",
       "5787                  0.0  #Gold prices slip below Rs 46,000 as #investor...   \n",
       "5788                  0.0  Workers at Bajaj Auto have agreed to a 10% wag...   \n",
       "5789                  0.0  #Sharemarket LIVE: Sensex off day’s high, up 6...   \n",
       "5790                  0.0  #Sensex, #Nifty climb off day's highs, still u...   \n",
       "\n",
       "             y  \n",
       "0     positive  \n",
       "1     positive  \n",
       "2     positive  \n",
       "3     positive  \n",
       "4     positive  \n",
       "...        ...  \n",
       "5786  negative  \n",
       "5787  negative  \n",
       "5788  positive  \n",
       "5789  positive  \n",
       "5790  positive  \n",
       "\n",
       "[5791 rows x 6 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainable_pipe = nlp.load('embed_sentence.bert train.sentiment')\n",
    "# Sentence embeddings other than USE usually need longer training and a smaller learning rate\n",
    "# The hyperparameters could be tuned further, e.g. with grid search\n",
    "# Longer training usually gives higher accuracy\n",
    "trainable_pipe['trainable_sentiment_dl'].setMaxEpochs(40)\n",
    "trainable_pipe['trainable_sentiment_dl'].setLr(0.0005)\n",
    "fitted_pipe = trainable_pipe.fit(train_df)\n",
    "# Predict on the training data with the fitted pipeline\n",
    "preds = fitted_pipe.predict(train_df,output_level='document')\n",
    "\n",
    "# The sentence detector in the pipe generates some NaNs, so drop them first\n",
    "preds.dropna(inplace=True)\n",
    "print(classification_report(preds['y'], preds['sentiment']))\n",
    "\n",
    "preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2BB-NwZUoHSe"
   },
   "source": [
    "# 5. Let's save the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eLex095goHwm"
   },
   "outputs": [],
   "source": [
    "stored_model_path = './models/classifier_dl_trained'\n",
    "fitted_pipe.save(stored_model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e_b2DPd4rCiU"
   },
   "source": [
    "# 6. Let's load the model from disk\n",
    "This makes offline NLU usage possible!   \n",
    "Call `nlp.load(path=path_to_the_pipe)` to load a model or pipeline from disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 133
    },
    "id": "SO4uz45MoRgp",
    "outputId": "f914162e-6c41-4355-ad11-20c92fd155d1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning::Spark Session already created, some configs may not take.\n",
      "Warning::Spark Session already created, some configs may not take.\n",
      "Warning::Spark Session already created, some configs may not take.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "  <div id=\"df-ae828ef5-f1c2-485e-a1df-16b5d781c027\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>document</th>\n",
       "      <th>sentence_embedding_from_disk</th>\n",
       "      <th>sentiment</th>\n",
       "      <th>sentiment_confidence</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Tesla plans to invest 10M into the ML sector</td>\n",
       "      <td>[-0.07111673802137375, 0.9532930850982666, -1....</td>\n",
       "      <td>positive</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-ae828ef5-f1c2-485e-a1df-16b5d781c027')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-ae828ef5-f1c2-485e-a1df-16b5d781c027 button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-ae828ef5-f1c2-485e-a1df-16b5d781c027');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                       document  \\\n",
       "0  Tesla plans to invest 10M into the ML sector   \n",
       "\n",
       "                        sentence_embedding_from_disk sentiment  \\\n",
       "0  [-0.07111673802137375, 0.9532930850982666, -1....  positive   \n",
       "\n",
       "  sentiment_confidence  \n",
       "0                  0.0  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hdd_pipe = nlp.load(path=stored_model_path)\n",
    "\n",
    "preds = hdd_pipe.predict('Tesla plans to invest 10M into the ML sector')\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "e0CVlkk9v6Qi",
    "outputId": "20c1ef9d-6769-4a30-ebae-cb37c2ec9f1d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The following parameters are configurable for this NLU pipeline (You can copy paste the examples) :\n",
      ">>> component_list['document_assembler'] has settable params:\n",
      "component_list['document_assembler'].setCleanupMode('shrink')                                  | Info: possible values: disabled, inplace, inplace_full, shrink, shrink_full, each, each_full, delete_full | Currently set to : shrink\n",
      ">>> component_list['bert_sentence_embeddings@sent_small_bert_L2_128'] has settable params:\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setBatchSize(8)              | Info: Size of every batch | Currently set to : 8\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setCaseSensitive(False)      | Info: whether to ignore case in tokens for embeddings matching | Currently set to : False\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setDimension(128)            | Info: Number of embedding dimensions | Currently set to : 128\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setMaxSentenceLength(128)    | Info: Max sentence length to process | Currently set to : 128\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setEngine('tensorflow')      | Info: Deep Learning engine used for this model | Currently set to : tensorflow\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setIsLong(False)             | Info: Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int. | Currently set to : False\n",
      "component_list['bert_sentence_embeddings@sent_small_bert_L2_128'].setStorageRef('sent_small_bert_L2_128')  | Info: unique reference name for identification | Currently set to : sent_small_bert_L2_128\n",
      ">>> component_list['sentiment_dl@sent_small_bert_L2_128'] has settable params:\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setThreshold(0.6)                        | Info: The minimum threshold for the final result otheriwse it will be neutral | Currently set to : 0.6\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setThresholdLabel('neutral')             | Info: In case the score is less than threshold, what should be the label. Default is neutral. | Currently set to : neutral\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setEngine('tensorflow')                  | Info: Deep Learning engine used for this model | Currently set to : tensorflow\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setClasses(['positive', 'negative'])     | Info: get the tags used to trained this SentimentDLModel | Currently set to : ['positive', 'negative']\n",
      "component_list['sentiment_dl@sent_small_bert_L2_128'].setStorageRef('sent_small_bert_L2_128')  | Info: unique reference name for identification | Currently set to : sent_small_bert_L2_128\n"
     ]
    }
   ],
   "source": [
    "hdd_pipe.print_info()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
